# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# 
import torch
import torch.nn as nn
import torch.optim as optim

from torchtext.data import Field, RawField, BucketIterator,TabularDataset,Dataset,Example
from torch.optim.lr_scheduler import LambdaLR

import torch.nn.functional as F
import spacy

from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import pairwise
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit
from prg import prg

import random
import math
import os
import time,json
import pdb, pickle

from modules.transformer_tree_model import *
from data_utils.Optim import *
from data_utils.train_vul import *

import data_utils.Constants as Constants
from data_utils.data_utils import *
import config
import argparse
import dgl

def preprocessing_batch(max_nodes_num, graphs, edges, device):
    """Rebuild every graph padded to ``max_nodes_num`` nodes and batch them.

    For each ``(graph, edge)`` pair a fresh ``dgl.DGLGraph`` is created from
    ``(edge[0], edge[2])``, the ``'type'`` edge feature is copied over from
    ``graph``, and extra nodes are appended until the graph holds
    ``max_nodes_num`` nodes.

    Args:
        max_nodes_num: Node count every rebuilt graph is padded up to.
        graphs: Iterable of DGL graphs that carry ``edata['type']``.
        edges: Iterable aligned with ``graphs``; items 0 and 2 of each entry
            are handed to ``dgl.DGLGraph`` as the edge endpoints.
        device: Device the batched graph is moved to.

    Returns:
        The batched, padded graph placed on ``device``.
    """
    def _rebuild_padded(source_graph, edge_spec):
        padded = dgl.DGLGraph((edge_spec[0], edge_spec[2]))
        padded.edata['type'] = source_graph.edata['type']
        missing_nodes = max_nodes_num - padded.number_of_nodes()
        padded.add_nodes(missing_nodes)
        return padded

    padded_graphs = [
        _rebuild_padded(source_graph, edge_spec)
        for source_graph, edge_spec in zip(graphs, edges)
    ]
    return dgl.batch(padded_graphs).to(device)

def _load_pickle(directory, file_name):
    """Unpickle and return the object stored in ``directory/file_name``."""
    # NOTE(security): pickle.load can execute arbitrary code while loading;
    # only point the dataset directory at files from a trusted source.
    with open(os.path.join(directory, file_name), 'rb') as f:
        return pickle.load(f)


def _run_epoch(batches, model, device, optimizer, criterion, train):
    """Run one pass over ``batches`` and aggregate the loss and metrics.

    Args:
        batches: Iterable yielding ``(index, batch)`` pairs; every batch must
            expose ``graph``, ``edge``, ``trg`` and ``nodes_num``.
        model: Model forwarded to ``train_eval``.
        device: Device the batched graphs are moved to.
        optimizer: Optimizer forwarded to ``train_eval``.
        criterion: Loss function forwarded to ``train_eval``.
        train: Forwarded to ``train_eval`` as its ``train`` flag.

    Returns:
        ``(mean_loss, metrics)``: the mean of the per-batch losses and the
        result of ``report`` over all labels and predictions of the pass.

    Raises:
        ValueError: If ``batches`` yields nothing.
    """
    all_preds = []
    all_labels = []
    all_loss = []
    for _, batch in batches:
        batch_graph_tmp = preprocessing_batch(max(batch.nodes_num), batch.graph, batch.edge, device)
        # Training and validation now both place the batched graph on `device`.
        batch_graph = dgl.batch(batch.graph).to(device)
        labels = torch.tensor(batch.trg)
        loss, preds = train_eval(batch_graph, labels, batch_graph_tmp, model, device,
                                 optimizer, criterion, train=train)
        all_preds.append(preds)
        all_labels.append(labels)
        all_loss.append(loss)

    if not all_loss:
        raise ValueError('cannot aggregate an epoch over an empty iterator')

    metrics = report(torch.cat(all_labels, dim=0), torch.cat(all_preds, dim=0))
    mean_loss = sum(all_loss).item() / len(all_loss)
    return mean_loss, metrics


def main():
    """Train the vulnerability-detection model on pickled assembly graphs.

    Reads command-line options through ``config.get_arg_parser``, loads the
    pickled vocabulary, graphs and targets from ``args.dataset_dir``, builds
    the GNN + encoder model and, when ``args.training`` is set and
    ``args.eval`` is not, trains for ``args.epoch_num`` epochs. After each
    epoch the model is validated, and its weights are saved to
    ``args.checkpoint_path`` whenever the validation loss improves.
    """
    title = 'vulnerability-detection'
    argParser = config.get_arg_parser(title)
    args = argParser.parse_args()
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.cache_path, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    vocab_asm = _load_pickle(args.dataset_dir, 'vocab_asm.pkl')
    dataset_asm = _load_pickle(args.dataset_dir, 'dataset_asm.pkl')
    tgt_asm = _load_pickle(args.dataset_dir, 'tgt_asm.pkl')

    SEED = 1234
    # torchtext's split/shuffle helpers draw from Python's `random` module, so
    # it has to be seeded as well for the data split to be reproducible.
    random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    Graph = RawField()
    EDGE = RawField()
    Nodes_num = RawField()
    TRG = RawField()
    fields = [('graph', Graph), ('edge', EDGE), ('trg', TRG), ('nodes_num', Nodes_num)]

    exp_list = []
    # Indexed access is kept because the pickled containers are only known to
    # support len() and integer lookup.
    for i in range(len(dataset_asm)):
        nodes_asm, edges_asm = dataset_asm[i]
        src_len = len(nodes_asm)
        if src_len > args.max_tolerate_len:
            # Skip graphs with more nodes than the configured limit.
            continue
        g = dgl.DGLGraph((edges_asm[0], edges_asm[2]))
        g.ndata['node_id'] = torch.arange(src_len, dtype=torch.long)
        g.ndata['annotation'] = torch.tensor(nodes_asm, dtype=torch.long)
        g.edata['type'] = torch.tensor(edges_asm[1])

        exp_list.append(Example.fromlist([g, edges_asm, tgt_asm[i], src_len], fields=fields))

    data_sets = Dataset(exp_list, fields=fields)
    # torchtext's Dataset.split reads the ratios as [train, test, valid] but
    # returns the splits ordered (train, valid, test); unpack in that order so
    # validation gets the 0.02 share and test the 0.08 share.
    trn, vld, tst = data_sets.split([0.9, 0.08, 0.02])

    max_len_src = args.max_tolerate_len

    print("Number of training examples: %d" % (len(trn.examples)))
    print("Number of validation examples: %d" % (len(vld.examples)))
    print("Number of testing examples: %d" % (len(tst.examples)))

    args.summary = TrainingSummaryWriter(args.log_dir)
    train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
        (trn, vld, tst),
        batch_size=args.batch_size,
        sort_key=None,
        sort_within_batch=False,
        sort=False,
        device=device)

    gnn = Graph_NN(annotation_size=len(vocab_asm),
                   out_feats=args.hid_dim,
                   n_steps=args.n_gnn_layers,
                   device=device,
                   gnn_type='ggnn',
                   tok_embedding=2,
                   residual=False)

    enc = Encoder(None,
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  device,
                  mem_dim=args.mem_dim,
                  embedding_flag=args.embedding_flag,
                  max_length=max_len_src)

    SRC_PAD_IDX = 0
    model = VUL_DETECT_ASM_Model(gnn, enc, SRC_PAD_IDX, device, args.hid_dim, 2).to(device)

    model.apply(initialize_weights)

    criterion = torch.nn.CrossEntropyLoss() if args.one_hot_label else torch.nn.BCELoss()
    optimizer = NoamOpt(args.hid_dim, args.lr_ratio, args.warmup,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    criterion.to(device)

    print("start training")

    best_valid_loss = float('inf')
    if args.training and not args.eval:
        save_checkpoints = args.checkpoint_path is not None
        checkpoint_file = None
        if save_checkpoints:
            # torch.save does not create missing directories; an empty path
            # means the current directory and needs no creation.
            if args.checkpoint_path:
                os.makedirs(args.checkpoint_path, exist_ok=True)
            checkpoint_file = os.path.join(args.checkpoint_path, 'model_vul_detection.pt')

        for epoch in range(args.epoch_num):
            loss, metrics = _run_epoch(
                init_tqdm(enumerate(train_iterator), 'train', log=args.log_dir),
                model, device, optimizer, criterion, train=True)
            print('==> Epoch {}, Train Loss {:.4f}\t'.format(epoch, loss) + report_to_str(metrics, keys=True))

            loss, metrics = _run_epoch(
                enumerate(valid_iterator),
                model, device, optimizer, criterion, train=False)
            print('==> Epoch {}, Valid Loss {:.4f}\t'.format(epoch, loss) + report_to_str(metrics, keys=True))

            # Keep only the weights with the lowest validation loss so far.
            if save_checkpoints and loss < best_valid_loss:
                best_valid_loss = loss
                torch.save(model.state_dict(), checkpoint_file)

    # all_preds = []
    # all_labels = []
    # all_loss = []
    # model.load_state_dict(torch.load(os.path.join(args.checkpoint_path, 'model_vul_detection.pt')))
    # model = model.to(device)
    # for i, batch in enumerate(test_iterator):
    #     batch_graph_tmp = preprocessing_batch(max(batch.nodes_num), batch.graph, batch.edge, device)
    #     batch_graph = dgl.batch(batch.graph)
    #     labels = torch.tensor(batch.trg)
    #     loss, preds = train_eval(batch_graph, labels, batch_graph_tmp, model, device, optimizer, criterion, train=False)
    #     all_preds += [preds]
    #     all_labels += [labels]
    #     all_loss += [loss]

    # all_preds = torch.cat(all_preds, dim=0)
    # all_labels = torch.cat(all_labels, dim=0)
    # metrics = report(all_labels, all_preds)
    # loss = sum(all_loss).item() / len(all_loss)

    # print('==> Epoch {},  Test Loss {:.4f}\t'.format(epoch, loss) + report_to_str(metrics, keys=True))

# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
